训练AI玩贪吃蛇(强化学习与Q 您所在的位置:网站首页 贪吃蛇 游戏规则 训练AI玩贪吃蛇(强化学习与Q

训练AI玩贪吃蛇(强化学习与Q

2023-08-25 01:45| 来源: 网络整理| 查看: 265

欢迎加入我们卧虎藏龙的python讨论qq群:729683466

●导 语 ●

AI近些年发展很火

不但可以下围棋

还可以玩各种各样的游戏

国外有人构建AI模型

然后训练出了非常厉害的拳击手

相信有一天

AI和机器人结合

肯定能训练出比泰森还强的高手

今天我们看一个简单的AI项目

AI自动玩贪吃蛇b45cab13b9385d34e9d36ce3c6be4e91.png

代码及相关资源获取

1:关注“python趣味爱好者”公众号,回复“AI玩贪吃蛇 ”获取源代码。

效果演示

游戏只训练了很短的时间,所以效果不太好,感兴趣的同学可以多训练几次。

41180fbf4af9f477490c6aeadff4502e.gif

往期精选

sdpython游戏:小蜜蜂历险记(一)

python小游戏之《小猫抓鱼》

 🛠

开发工具

python3.10

第三方库:pygame,collections,enum

讲解部分

一:原理介绍

        这里用到的核心知识是强化学习,这是一种无监督的算法,与有标签的监督学习不同,在深度学习中,监督学习的意思就是需要人工标注的正确答案,通过这些正确答案来训练神经网络,最终使神经网络可以准确模拟出下一步的操作。

        但是在无监督学习中,并没有人告诉程序,你走的哪一步是正确的,哪一步是错误的。程序自动不断试错,如果撞到了墙壁或者尾巴直接死亡,吃到食物获得奖励。这样神经网络才知道哪些动作正确,哪些不正确。

二:Q-learning算法

        在Q-learning算法自我训练的过程中,有5个比较重要的元素来评估当前状态的价值,首先是state,也就是当前的状态,对于贪吃蛇来说,这个状态很简单,就是蛇的位置,食物的位置,蛇移动方向等等。将这些信息组合在一起,形成一个向量,这就是当前状态,然后第二个元素就是蛇的选择action,一共4个方向,向前是无效的,向后是自杀,所以只有两个方向,向左向右(相对于蛇的移动方向来说)。这个选择就是action。随后,我们把state做出选择action后产生的新状态写作state1,这是第三个元素。如果state1是游戏结束(死亡),那么第四个元素done就是False,否则就是True。最后一个元素就是reward,也就是这个状态state1的奖励,如果吃到了食物就是奖励10分,否则就0分,如果死亡了,分数那就是-10分。

用一个图表示上述状态就是。

01762ac478a0405209f74a731c935064.png

Q-learning核心公式如下:

da8616d00175fc3ab78b3e1ad8bae1bb.png

在代码中实现公式如下:

Q_new = reward[idx] + self.gamma * torch.max(self.model(next_state[idx]))

用于评估每个状态的value,以训练神经网络

整个Q-learning的代码如下:

class QTrainer: def __init__(self,model,lr,gamma): self.lr = lr self.gamma = gamma self.model = model self.optimer = optim.Adam(model.parameters(),lr = self.lr) self.criterion = nn.MSELoss() for i in self.model.parameters(): print(i.is_cuda) def train_step(self,state,action,reward,next_state,done): state = torch.tensor(state,dtype=torch.float) next_state = torch.tensor(next_state,dtype=torch.float) action = torch.tensor(action,dtype=torch.long) reward = torch.tensor(reward,dtype=torch.float) if(len(state.shape) == 1): # only one parameter to train , Hence convert to tuple of shape (1, x) #(1 , x) state = torch.unsqueeze(state,0) next_state = torch.unsqueeze(next_state,0) action = torch.unsqueeze(action,0) reward = torch.unsqueeze(reward,0) done = (done, ) # 1. Predicted Q value with current state pred = self.model(state) target = pred.clone() for idx in range(len(done)): Q_new = reward[idx] if not done[idx]: #Q_new = reward[idx] + self.gamma * torch.max(self.model(next_state[idx])).cuda() Q_new = reward[idx] + self.gamma * torch.max(self.model(next_state[idx])) target[idx][torch.argmax(action).item()] = Q_new # 2. Q_new = reward + gamma * max(next_predicted Qvalue) -> only do this if not done # pred.clone() # preds[argmax(action)] = Q_new self.optimer.zero_grad() loss = self.criterion(target,pred) loss.backward() self.optimer.step()

三:神经网络

神经网络是深度学习的基本概念,这里不过多介绍,直接展示一下我们设计的神经网络代码,相信大家能看懂。

class Linear_QNet(nn.Module): def __init__(self,input_size,hidden_size,output_size): super().__init__() # self.linear1 = nn.Linear(input_size,hidden_size).cuda() # self.linear2 = nn.Linear(hidden_size,output_size).cuda() self.linear1 = nn.Linear(input_size,hidden_size) self.linear2 = nn.Linear(hidden_size,output_size) def forward(self, x): x = F.relu(self.linear1(x)) x = self.linear2(x) return x

作者|齐

编辑|齐

感谢大家观看

有钱的老板可打赏一下小编哦

扫描二维码

关注我们

QQ群:729683466

◰  

参考来源

代码来源:github



【本文地址】

公司简介

联系我们

今日新闻

    推荐新闻

      专题文章
        CopyRight 2018-2019 实验室设备网 版权所有